"""Test the configuration checks that reject some bad compile-time configs.

This tests the output of ``generate_config_checks.py``.
This can also let us verify what we enforce in the manually written
checks in ``<PROJECT>_check_config.h`` and ``<PROJECT>_config.c``.
"""

## Copyright The Mbed TLS Contributors
## SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later

import os
import subprocess
import sys
import unittest
from typing import List, Optional, Pattern, Union


class TestConfigChecks(unittest.TestCase):
    """Unit tests for checks performed via ``<PROJECT>_config.c``.

    This can test the code generated by `config_checks_generator`,
    as well as manually written checks in `check_config.h`.

    This class is meant to be subclassed: a concrete test class sets
    ``PROJECT_CONFIG_C`` (and optionally
    ``PROJECT_SPECIFIC_INCLUDE_DIRECTORIES``) and adds test methods that
    call ``good_case()`` and ``bad_case()``.

    All relative paths (``include``, ``.``, temporary user configuration
    headers) are resolved against the current working directory.
    """

    # Set this to the path to the source file containing the config checks.
    PROJECT_CONFIG_C = None #type: Optional[str]

    # Project-specific include directories (in addition to /include)
    PROJECT_SPECIFIC_INCLUDE_DIRECTORIES = [] #type: List[str]

    # Increase the length of strings that assertion failures are willing to
    # print. This is useful for failures where the compiler has a lot
    # to say.
    maxDiff = 9999

    def setUp(self) -> None:
        """Reset the per-test state.

        Skip the test when it runs on this base class itself without
        ``PROJECT_CONFIG_C``: the base class cannot compile anything, but
        it gets collected by test loaders whenever it is imported by name
        into a test module. Subclasses that forget to set
        ``PROJECT_CONFIG_C`` are not skipped: they still fail.
        """
        self.cc_output = None #type: Optional[str]
        if self.__class__ is TestConfigChecks and \
           self.PROJECT_CONFIG_C is None:
            self.skipTest('PROJECT_CONFIG_C is not set (abstract base class)')

    def tearDown(self) -> None:
        """Log the compiler output to a file, if available and desired.

        This is intended for debugging. It only happens if the environment
        variable UNITTEST_CONFIG_CHECKS_DEBUG is non-empty.
        """
        if not os.getenv('UNITTEST_CONFIG_CHECKS_DEBUG'):
            return
        # We set self.cc_output to the compiler output before
        # asserting, and set it to None if all the assertions pass.
        if self.cc_output is None:
            return
        basename = os.path.splitext(os.path.basename(sys.argv[0]))[0]
        filename = f'{basename}.{self._testMethodName}.out.txt'
        # The compiler output was decoded as UTF-8, so write it back as
        # UTF-8 rather than in the locale's encoding, which might not be
        # able to represent every character of the output.
        with open(filename, 'w', encoding='utf-8') as out:
            out.write(self.cc_output)

    @staticmethod
    def _remove_if_exists(file_name: str) -> None:
        """Remove the given file. Do nothing if it does not exist."""
        try:
            os.remove(file_name)
        except FileNotFoundError:
            pass

    def user_config_file_name(self, variant: str) -> str:
        """Construct a unique temporary file name for a user config header.

        The name is relative to the current directory. It combines
        the given variant, the name of the running script, the process ID
        and the identity of this test object.
        """
        name = os.path.splitext(os.path.basename(sys.argv[0]))[0]
        pid = str(os.getpid())
        oid = str(id(self))
        return f'tmp-user_config_{variant}-{name}-{pid}-{oid}.h'

    def write_user_config(self, variant: str, content: Optional[str]) -> Optional[str]:
        """Write a user configuration file with the given content.

        If content is None, ensure the file does not exist.
        If content is non-empty and does not end with a newline,
        a trailing newline is added.

        Return None if content is None, otherwise return the file name.
        """
        file_name = self.user_config_file_name(variant)
        if content is None:
            self._remove_if_exists(file_name)
            return None
        if content and not content.endswith('\n'):
            content += '\n'
        with open(file_name, 'w', encoding='utf-8') as out:
            out.write(content)
        return file_name

    def run_with_config_files(self,
                              crypto_user_config_file: Optional[str],
                              mbedtls_user_config_file: Optional[str],
                              extra_options: List[str],
                              ) -> subprocess.CompletedProcess:
        """Run cc with the given user configuration files.

        The compiler is taken from the environment variable CC, defaulting
        to ``cc``. Each user configuration file that is not None is passed
        through the corresponding ``XXX_USER_CONFIG_FILE`` macro.
        The object file is discarded.

        Fail the test if PROJECT_CONFIG_C is not set.

        Return the CompletedProcess object capturing the return code,
        stdout and stderr.
        """
        # Use self.fail() rather than a bare assert so that the check
        # still happens, with a clear message, when running `python -O`.
        if self.PROJECT_CONFIG_C is None:
            self.fail('PROJECT_CONFIG_C must be set by the test class')
        cmd = [os.getenv('CC', 'cc')]
        if os.getenv('UNITTEST_CONFIG_CHECKS_DEBUG'):
            cmd += ['-dD']
        if crypto_user_config_file is not None:
            cmd.append(f'-DTF_PSA_CRYPTO_USER_CONFIG_FILE="{crypto_user_config_file}"')
        if mbedtls_user_config_file is not None:
            cmd.append(f'-DMBEDTLS_USER_CONFIG_FILE="{mbedtls_user_config_file}"')
        cmd += extra_options
        cmd += ['-I' + directory
                for directory in self.PROJECT_SPECIFIC_INCLUDE_DIRECTORIES]
        cmd += ['-Iinclude',
                '-I.',
                '-I' + os.path.dirname(self.PROJECT_CONFIG_C)]
        cmd += ['-o', os.devnull, '-c', self.PROJECT_CONFIG_C]
        return subprocess.run(cmd,
                              check=False,
                              encoding='utf-8',
                              stdout=subprocess.PIPE,
                              stderr=subprocess.PIPE)

    def run_with_config(self,
                        crypto_user_config: Optional[str],
                        mbedtls_user_config: Optional[str] = None,
                        extra_options: Optional[List[str]] = None,
                        ) -> subprocess.CompletedProcess:
        """Run cc with the given content for user configuration files.

        Each xxx_user_config is the content of a temporary header,
        or None to compile without that user configuration file.
        The temporary headers are removed before returning, even on error.

        Return the CompletedProcess object capturing the return code,
        stdout and stderr.
        """
        if extra_options is None:
            extra_options = []
        try:
            # Create temporary files without using tempfile because:
            # 1. Before Python 3.12, tempfile.NamedTemporaryFile does
            #    not have good support for allowing an external program
            #    to access the file on Windows.
            # 2. With a tempfile-provided context, it's awkward to not
            #    create a file optionally (we only do it when xxx_user_config
            #    is not None).
            crypto_user_config_file = \
                self.write_user_config('crypto', crypto_user_config)
            mbedtls_user_config_file = \
                self.write_user_config('mbedtls', mbedtls_user_config)
            return self.run_with_config_files(crypto_user_config_file,
                                              mbedtls_user_config_file,
                                              extra_options)
        finally:
            # Clean up by file name rather than by the values returned
            # by write_user_config(), so that a file is removed even if
            # writing it failed halfway through.
            for variant in ('crypto', 'mbedtls'):
                self._remove_if_exists(self.user_config_file_name(variant))

    def good_case(self,
                  crypto_user_config: Optional[str],
                  mbedtls_user_config: Optional[str] = None,
                  extra_options: Optional[List[str]] = None,
                  ) -> None:
        """Run cc with the given user config(s). Expect no error.

        Pass extra_options on the command line of cc.

        The standard error from cc must be empty and cc must return 0.
        """
        cp = self.run_with_config(crypto_user_config, mbedtls_user_config,
                                  extra_options=extra_options)
        # Assert the error text before the status. That way, if it fails,
        # we see the unexpected error messages in the test log.
        self.cc_output = cp.stdout
        self.assertEqual(cp.stderr, '')
        self.assertEqual(cp.returncode, 0)
        self.cc_output = None

    def bad_case(self,
                 crypto_user_config: Optional[str],
                 mbedtls_user_config: Optional[str] = None,
                 error: Optional[Union[str, Pattern]] = None,
                 extra_options: Optional[List[str]] = None,
                 ) -> None:
        """Run cc with the given user config(s). Expect errors.

        Pass extra_options on the command line of cc.

        If error is given, the standard error from cc must match this regex.

        The return code of cc must be strictly between 0 and 126.
        """
        cp = self.run_with_config(crypto_user_config, mbedtls_user_config,
                                  extra_options=extra_options)
        self.cc_output = cp.stdout
        if error is not None:
            # Assert the error text before the status. That way, if it fails,
            # we see the unexpected error messages in the test log.
            self.assertRegex(cp.stderr, error)
        self.assertGreater(cp.returncode, 0)
        self.assertLess(cp.returncode, 126)
        self.cc_output = None

    # Nominal case, run first
    def test_01_nominal(self) -> None:
        """Compile without any user configuration file. Expect no error."""
        self.good_case(None)

    # Trivial error case, run second
    def test_02_error(self) -> None:
        """Compile a crypto user configuration containing ``#error``."""
        self.bad_case('#error "Bad crypto configuration"',
                      error='"Bad crypto configuration"')
